{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to further train a pre-trained model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will demonstrate how to freeze some or all of the layers of a pre-trained model and continue training using a new fully-connected set of layers and data with a different format."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports & Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_files       \n",
    "from keras.utils import np_utils\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "\n",
    "from keras.datasets import cifar10\n",
    "from keras.utils import to_categorical\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "from keras.applications.vgg16 import VGG16\n",
    "from keras.layers import Dense, Flatten, Dropout\n",
    "from keras.models import Sequential, Model \n",
    "from keras.callbacks import ModelCheckpoint, TensorBoard\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Dog Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before running the code cell below, download the dataset of dog images [here](https://s3-us-west-1.amazonaws.com/udacity-aind/dog-project/dogImages.zip)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "(X_train, y_train), (X_test, y_test) = cifar10.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar10_labels = {0: 'airplane',\n",
    "                  1: 'automobile',\n",
    "                  2: 'bird',\n",
    "                  3: 'cat',\n",
    "                  4: 'deer',\n",
    "                  5: 'dog',\n",
    "                  6: 'frog',\n",
    "                  7: 'horse',\n",
    "                  8: 'ship',\n",
    "                  9: 'truck'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = len(cifar10_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train = to_categorical(y_train, num_classes)\n",
    "y_test = to_categorical(y_test, num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "# X_train, X_valid = X_train[5000:], X_train[:5000]\n",
    "# y_train, y_valid = y_train[5000:], y_train[:5000]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Obtain the VGG-16 Bottleneck Features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the VGG16 weights, pre-trained on ImageNet with the much smaller 32 x 32 CIFAR10 data. Note that we indicate the new input size upon import and set all layers to not trainable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_7 (InputLayer)         (None, 32, 32, 3)         0         \n",
      "_________________________________________________________________\n",
      "block1_conv1 (Conv2D)        (None, 32, 32, 64)        1792      \n",
      "_________________________________________________________________\n",
      "block1_conv2 (Conv2D)        (None, 32, 32, 64)        36928     \n",
      "_________________________________________________________________\n",
      "block1_pool (MaxPooling2D)   (None, 16, 16, 64)        0         \n",
      "_________________________________________________________________\n",
      "block2_conv1 (Conv2D)        (None, 16, 16, 128)       73856     \n",
      "_________________________________________________________________\n",
      "block2_conv2 (Conv2D)        (None, 16, 16, 128)       147584    \n",
      "_________________________________________________________________\n",
      "block2_pool (MaxPooling2D)   (None, 8, 8, 128)         0         \n",
      "_________________________________________________________________\n",
      "block3_conv1 (Conv2D)        (None, 8, 8, 256)         295168    \n",
      "_________________________________________________________________\n",
      "block3_conv2 (Conv2D)        (None, 8, 8, 256)         590080    \n",
      "_________________________________________________________________\n",
      "block3_conv3 (Conv2D)        (None, 8, 8, 256)         590080    \n",
      "_________________________________________________________________\n",
      "block3_pool (MaxPooling2D)   (None, 4, 4, 256)         0         \n",
      "_________________________________________________________________\n",
      "block4_conv1 (Conv2D)        (None, 4, 4, 512)         1180160   \n",
      "_________________________________________________________________\n",
      "block4_conv2 (Conv2D)        (None, 4, 4, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "block4_conv3 (Conv2D)        (None, 4, 4, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "block4_pool (MaxPooling2D)   (None, 2, 2, 512)         0         \n",
      "_________________________________________________________________\n",
      "block5_conv1 (Conv2D)        (None, 2, 2, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "block5_conv2 (Conv2D)        (None, 2, 2, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "block5_conv3 (Conv2D)        (None, 2, 2, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "block5_pool (MaxPooling2D)   (None, 1, 1, 512)         0         \n",
      "=================================================================\n",
      "Total params: 14,714,688\n",
      "Trainable params: 14,714,688\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "vgg16 = VGG16(include_top=False, input_shape =X_train.shape[1:])\n",
    "vgg16.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Freeze model layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Selectively freeze layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [],
   "source": [
    "for layer in vgg16.layers:\n",
    "    layer.trainable = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_6 (InputLayer)         (None, 32, 32, 3)         0         \n",
      "_________________________________________________________________\n",
      "block1_conv1 (Conv2D)        (None, 32, 32, 64)        1792      \n",
      "_________________________________________________________________\n",
      "block1_conv2 (Conv2D)        (None, 32, 32, 64)        36928     \n",
      "_________________________________________________________________\n",
      "block1_pool (MaxPooling2D)   (None, 16, 16, 64)        0         \n",
      "_________________________________________________________________\n",
      "block2_conv1 (Conv2D)        (None, 16, 16, 128)       73856     \n",
      "_________________________________________________________________\n",
      "block2_conv2 (Conv2D)        (None, 16, 16, 128)       147584    \n",
      "_________________________________________________________________\n",
      "block2_pool (MaxPooling2D)   (None, 8, 8, 128)         0         \n",
      "_________________________________________________________________\n",
      "block3_conv1 (Conv2D)        (None, 8, 8, 256)         295168    \n",
      "_________________________________________________________________\n",
      "block3_conv2 (Conv2D)        (None, 8, 8, 256)         590080    \n",
      "_________________________________________________________________\n",
      "block3_conv3 (Conv2D)        (None, 8, 8, 256)         590080    \n",
      "_________________________________________________________________\n",
      "block3_pool (MaxPooling2D)   (None, 4, 4, 256)         0         \n",
      "_________________________________________________________________\n",
      "block4_conv1 (Conv2D)        (None, 4, 4, 512)         1180160   \n",
      "_________________________________________________________________\n",
      "block4_conv2 (Conv2D)        (None, 4, 4, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "block4_conv3 (Conv2D)        (None, 4, 4, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "block4_pool (MaxPooling2D)   (None, 2, 2, 512)         0         \n",
      "_________________________________________________________________\n",
      "block5_conv1 (Conv2D)        (None, 2, 2, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "block5_conv2 (Conv2D)        (None, 2, 2, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "block5_conv3 (Conv2D)        (None, 2, 2, 512)         2359808   \n",
      "_________________________________________________________________\n",
      "block5_pool (MaxPooling2D)   (None, 1, 1, 512)         0         \n",
      "=================================================================\n",
      "Total params: 14,714,688\n",
      "Trainable params: 0\n",
      "Non-trainable params: 14,714,688\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "vgg16.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add new layers to model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use Keras’ functional API to define the vgg16 output as input into a new set of fully-connected layers like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Adding custom Layers \n",
    "x = vgg16.output\n",
    "x = Flatten()(x)\n",
    "x = Dense(512, activation=\"relu\")(x)\n",
    "x = Dropout(0.5)(x)\n",
    "x = Dense(256, activation=\"relu\")(x)\n",
    "predictions = Dense(10, activation=\"softmax\")(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define a new model in terms of inputs and output, and proceed from there on as before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "transfer_model = Model(inputs = vgg16.input, \n",
    "                       outputs = predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "transfer_model.compile(loss = 'categorical_crossentropy', \n",
    "                       optimizer = 'Adam', \n",
    "                       metrics=[\"accuracy\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_split = .1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use a more elaborate ImageDataGenerator that also defines a validation_split:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "datagen = ImageDataGenerator(\n",
    "        rescale=1. / 255,\n",
    "        horizontal_flip=True,\n",
    "        fill_mode='nearest',\n",
    "        zoom_range=0.1,\n",
    "        width_shift_range=0.1,\n",
    "        height_shift_range=0.1,\n",
    "        rotation_range=30,\n",
    "        validation_split=validation_split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size =32\n",
    "epochs = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define both train- and validation generators for the fit method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_generator = datagen.flow(X_train, \n",
    "                               y_train, \n",
    "                               subset='training')\n",
    "val_generator = datagen.flow(X_train, \n",
    "                             y_train, \n",
    "                             subset='validation')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "vgg16_path = 'models/cifar10.transfer.vgg16.weights.best.hdf5'\n",
    "checkpointer = ModelCheckpoint(filepath=vgg16_path, \n",
    "                               verbose=1, \n",
    "                               save_best_only=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we proceed to train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "1562/1562 [==============================] - 16s 10ms/step - loss: 1.5325 - acc: 0.4553 - val_loss: 1.3096 - val_acc: 0.5438\n",
      "\n",
      "Epoch 00001: val_loss improved from inf to 1.30961, saving model to models/cifar10.transfer.vgg16.weights.best.hdf5\n",
      "Epoch 2/10\n",
      "1562/1562 [==============================] - 15s 10ms/step - loss: 1.3717 - acc: 0.5138 - val_loss: 1.2726 - val_acc: 0.5532\n",
      "\n",
      "Epoch 00002: val_loss improved from 1.30961 to 1.27260, saving model to models/cifar10.transfer.vgg16.weights.best.hdf5\n",
      "Epoch 3/10\n",
      "1562/1562 [==============================] - 15s 10ms/step - loss: 1.3253 - acc: 0.5339 - val_loss: 1.2515 - val_acc: 0.5591\n",
      "\n",
      "Epoch 00003: val_loss improved from 1.27260 to 1.25149, saving model to models/cifar10.transfer.vgg16.weights.best.hdf5\n",
      "Epoch 4/10\n",
      "1562/1562 [==============================] - 15s 10ms/step - loss: 1.3060 - acc: 0.5410 - val_loss: 1.2249 - val_acc: 0.5715\n",
      "\n",
      "Epoch 00004: val_loss improved from 1.25149 to 1.22492, saving model to models/cifar10.transfer.vgg16.weights.best.hdf5\n",
      "Epoch 5/10\n",
      "1562/1562 [==============================] - 15s 10ms/step - loss: 1.2715 - acc: 0.5509 - val_loss: 1.2011 - val_acc: 0.5766\n",
      "\n",
      "Epoch 00005: val_loss improved from 1.22492 to 1.20108, saving model to models/cifar10.transfer.vgg16.weights.best.hdf5\n",
      "Epoch 6/10\n",
      "1562/1562 [==============================] - 15s 10ms/step - loss: 1.2526 - acc: 0.5595 - val_loss: 1.1950 - val_acc: 0.5868\n",
      "\n",
      "Epoch 00006: val_loss improved from 1.20108 to 1.19496, saving model to models/cifar10.transfer.vgg16.weights.best.hdf5\n",
      "Epoch 7/10\n",
      "1562/1562 [==============================] - 15s 10ms/step - loss: 1.2387 - acc: 0.5635 - val_loss: 1.1926 - val_acc: 0.5783\n",
      "\n",
      "Epoch 00007: val_loss improved from 1.19496 to 1.19262, saving model to models/cifar10.transfer.vgg16.weights.best.hdf5\n",
      "Epoch 8/10\n",
      "1562/1562 [==============================] - 15s 10ms/step - loss: 1.2249 - acc: 0.5678 - val_loss: 1.1779 - val_acc: 0.5948\n",
      "\n",
      "Epoch 00008: val_loss improved from 1.19262 to 1.17794, saving model to models/cifar10.transfer.vgg16.weights.best.hdf5\n",
      "Epoch 9/10\n",
      "1562/1562 [==============================] - 15s 10ms/step - loss: 1.2177 - acc: 0.5700 - val_loss: 1.1687 - val_acc: 0.5927\n",
      "\n",
      "Epoch 00009: val_loss improved from 1.17794 to 1.16873, saving model to models/cifar10.transfer.vgg16.weights.best.hdf5\n",
      "Epoch 10/10\n",
      "1562/1562 [==============================] - 15s 10ms/step - loss: 1.2042 - acc: 0.5739 - val_loss: 1.1826 - val_acc: 0.5862\n",
      "\n",
      "Epoch 00010: val_loss did not improve from 1.16873\n"
     ]
    }
   ],
   "source": [
    "transfer_model.fit_generator(train_generator,\n",
    "                             steps_per_epoch=X_train.shape[0] // batch_size,\n",
    "                             epochs=epochs,\n",
    "                             validation_data=val_generator,\n",
    "                             validation_steps=(X_train.shape[0] * .2) // batch_size,\n",
    "                             callbacks=[checkpointer],\n",
    "                             verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the weights that yielded the best validation accuracy\n",
    "transfer_model.load_weights(vgg16_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000/10000 [==============================] - 2s 236us/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.3587"
      ]
     },
     "execution_count": 112,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transfer_model.evaluate(X_test, y_test)[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test Set Classification Accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "10 epochs lead to a mediocre test accuracy of 35.87% because the assumption that image features translate to so much smaller images is somewhat questionable but it serves to illustrate the workflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get index of predicted dog breed for each image in test set\n",
    "vgg16_predictions = np.argmax(transfer_model.predict(X_test), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test accuracy: 0.3587%\n"
     ]
    }
   ],
   "source": [
    "test_accuracy = np.sum(vgg16_predictions==np.argmax(y_test, axis=1))/len(vgg16_predictions)\n",
    "print('\\nTest accuracy: %.4f%%' % test_accuracy)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "318.55px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
